
###############################
# class for single heatmap
#


# the layout of the heatmap is 9 x 7 (rows x columns)

# == title
# Class for a single heatmap
#
# == details
# The components for a single heatmap are placed into a 9 x 7 layout (9 rows, 7 columns):
#
#          +------+ (1)
#          +------+ (2)
#          +------+ (3)
#          +------+ (4)
#    +-+-+-+------+-+-+-+
#    |1|2|3| 4(5) |5|6|7|
#    +-+-+-+------+-+-+-+
#          +------+ (6)
#          +------+ (7)
#          +------+ (8)
#          +------+ (9)
#
# From top to bottom in column 4, the regions are:
#
# - title which is put on the top of the heatmap, graphics are drawn by `draw_title,Heatmap-method`.
# - column cluster on the top, graphics are drawn by `draw_dend,Heatmap-method`.
# - column annotation on the top, graphics are drawn by `draw_annotation,Heatmap-method`.
# - column names on the top, graphics are drawn by `draw_dimnames,Heatmap-method`.
# - heatmap body, graphics are drawn by `draw_heatmap_body,Heatmap-method`.
# - column names on the bottom, graphics are drawn by `draw_dimnames,Heatmap-method`.
# - column annotation on the bottom, graphics are drawn by `draw_annotation,Heatmap-method`.
# - column cluster on the bottom, graphics are drawn by `draw_dend,Heatmap-method`.
# - title on the bottom, graphics are drawn by `draw_title,Heatmap-method`.
# 
# From left to right in row 5, the regions are:
#
# - title which is put in the left of the heatmap, graphics are drawn by `draw_title,Heatmap-method`.
# - row cluster on the left, graphics are drawn by `draw_dend,Heatmap-method`.
# - row names on the left, graphics are drawn by `draw_dimnames,Heatmap-method`.
# - heatmap body
# - row names on the right, graphics are drawn by `draw_dimnames,Heatmap-method`.
# - row cluster on the right, graphics are drawn by `draw_dend,Heatmap-method`.
# - title on the right, graphics are drawn by `draw_title,Heatmap-method`.
#
# The `Heatmap-class` is not responsible for heatmap legend and annotation legends. The `draw,Heatmap-method` method
# will construct a `HeatmapList-class` object which only contains one single heatmap
# and call `draw,HeatmapList-method` to make a complete heatmap.
#
# == methods
# The `Heatmap-class` provides following methods:
#
# - `Heatmap`: constructor method.
# - `draw,Heatmap-method`: draw a single heatmap.
# - `add_heatmap,Heatmap-method`: append heatmaps and row annotations to a list of heatmaps.
# - `row_order,HeatmapList-method`: get order of rows
# - `column_order,HeatmapList-method`: get order of columns
# - `row_dend,HeatmapList-method`: get row dendrograms
# - `column_dend,HeatmapList-method`: get column dendrograms
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
Heatmap = setClass("Heatmap",
    # slot name = slot class, given as a named character vector
    slots = c(
        name = "character",

        # the matrix (one or more matrices which are spliced by rows) and its settings
        matrix = "matrix",
        matrix_param = "list",
        matrix_color_mapping = "ANY",
        matrix_color_mapping_param = "ANY",

        # row and column titles
        row_title = "ANY",
        row_title_rot = "numeric",
        row_title_just = "numeric",
        row_title_param = "list",
        column_title = "ANY",
        column_title_param = "list",
        column_title_rot = "numeric",
        column_title_just = "numeric",

        # row clustering: one or more row dendrograms, their parameters and orders
        row_dend_list = "list",
        row_dend_param = "list",
        row_order_list = "list",
        row_order = "numeric",

        # column clustering: a single dendrogram, its parameters and the order
        column_dend = "ANY",
        column_dend_param = "list",
        column_order = "numeric",

        # row and column names
        row_names_param = "list",
        column_names_param = "list",

        # column annotations: NULL or a `HeatmapAnnotation` object, plus settings
        top_annotation = "ANY",
        top_annotation_param = "list",
        bottom_annotation = "ANY",
        bottom_annotation_param = "list",

        # heatmap-level settings and layout bookkeeping
        heatmap_param = "list",
        layout = "list"
    ),
    contains = "AdditiveUnit"
)



# == title
# Constructor method for Heatmap class
#
# == param
# -matrix a matrix. Either numeric or character. If it is a simple vector, it will be
#         converted to a one-column matrix.
# -col a vector of colors if the color mapping is discrete or a color mapping 
#      function if the matrix is continuous numbers (should be generated by `circlize::colorRamp2`). If the matrix is continuous,
#      the value can also be a vector of colors so that colors will be interpolated. Pass to `ColorMapping`.
# -name name of the heatmap. The name is used as the title of the heatmap legend.
# -na_col color for ``NA`` values.
# -rect_gp graphic parameters for drawing rectangles (for heatmap body).
# -color_space the color space in which colors are interpolated. Only used if ``matrix`` is numeric and 
#            ``col`` is a vector of colors. Pass to `circlize::colorRamp2`.
# -cell_fun self-defined function to add graphics on each cell. Seven parameters will be passed into 
#           this function: ``i``, ``j``, ``x``, ``y``, ``width``, ``height``, ``fill`` which are row index,
#           column index in ``matrix``, coordinate of the middle points in the heatmap body viewport,
#           the width and height of the cell and the filled color. ``x``, ``y``, ``width`` and ``height`` are all `grid::unit` objects.
# -row_title title on row.
# -row_title_side will the title be put on the left or right of the heatmap?
# -row_title_gp graphic parameters for drawing text.
# -row_title_rot rotation of row titles. Only 0, 90, 270 are allowed to set.
# -column_title title on column.
# -column_title_side will the title be put on the top or bottom of the heatmap?
# -column_title_gp graphic parameters for drawing text.
# -column_title_rot rotation of column titles. Only 0, 90, 270 are allowed to set.
# -cluster_rows If the value is a logical, it means whether to make clusters on rows. The value can also
#               be a `stats::hclust` or a `stats::dendrogram` that already contains clustering information.
#               This means you can use any type of clustering methods and render the `stats::dendrogram`
#               object with self-defined graphic settings.
# -clustering_distance_rows it can be a pre-defined character which is in 
#                ("euclidean", "maximum", "manhattan", "canberra", "binary", 
#                "minkowski", "pearson", "spearman", "kendall"). It can also be a function.
#                If the function has one argument, the input argument should be a matrix and 
#                the returned value should be a `stats::dist` object. If the function has two arguments,
#                the input arguments are two vectors and the function calculates distance between these
#                two vectors.
# -clustering_method_rows method to make cluster, pass to `stats::hclust`.
# -row_dend_side should the row cluster be put on the left or right of the heatmap?
# -row_dend_width width of the row cluster, should be a `grid::unit` object.
# -show_row_dend whether show row clusters. 
# -row_dend_gp graphics parameters for drawing lines. If users already provide a `stats::dendrogram`
#                object with edges rendered, this argument will be ignored.
# -row_dend_reorder apply reordering on rows. The value can be a logical value or a vector which contains weight 
#               which is used to reorder rows
# -row_hclust_side deprecated, use ``row_dend_side`` instead
# -row_hclust_width deprecated, use ``row_dend_width`` instead
# -show_row_hclust deprecated, use ``show_row_dend`` instead
# -row_hclust_gp deprecated, use ``row_dend_gp`` instead
# -row_hclust_reorder deprecated, use ``row_dend_reorder`` instead
# -cluster_columns whether make cluster on columns. Same settings as ``cluster_rows``.
# -clustering_distance_columns same setting as ``clustering_distance_rows``.
# -clustering_method_columns method to make cluster, pass to `stats::hclust`.
# -column_dend_side should the column cluster be put on the top or bottom of the heatmap?
# -column_dend_height height of the column cluster, should be a `grid::unit` object.
# -show_column_dend whether show column clusters.
# -column_dend_gp graphic parameters for drawing lines. Same settings as ``row_dend_gp``.
# -column_dend_reorder apply reordering on columns. The value can be a logical value or a vector which contains weight 
#               which is used to reorder columns
# -column_hclust_side deprecated, use ``column_dend_side`` instead
# -column_hclust_height deprecated, use ``column_dend_height`` instead
# -show_column_hclust deprecated, use ``show_column_dend`` instead
# -column_hclust_gp deprecated, use ``column_dend_gp`` instead
# -column_hclust_reorder deprecated, use ``column_dend_reorder`` instead
# -row_order order of rows. It makes it easy to adjust row order for a list of heatmaps if this heatmap 
#      is selected as the main heatmap. Manually setting row order should turn off clustering
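# -column_order order of columns. It makes it easy to adjust column order for both matrix and column annotations.
#      It is ignored when columns are clustered, because the order is then taken from the column dendrogram.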
# -row_names_side should the row names be put on the left or right of the heatmap?
# -show_row_names whether show row names.
# -row_names_max_width maximum width of row names viewport. Because some times row names can be very long, it is not reasonable
#                      to show them all.
# -row_names_gp graphic parameters for drawing text.
# -column_names_side should the column names be put on the top or bottom of the heatmap?
# -column_names_max_height maximum height of column names viewport.
# -show_column_names whether show column names.
# -column_names_gp graphic parameters for drawing text.
# -top_annotation a `HeatmapAnnotation` object which contains a list of annotations.
# -top_annotation_height total height of the column annotations on the top.
# -bottom_annotation a `HeatmapAnnotation` object.
# -bottom_annotation_height total height of the column annotations on the bottom.
# -km do k-means clustering on rows. If the value is larger than 1, the heatmap will be split by rows according to the k-means clustering.
#     For each row-clusters, hierarchical clustering is still applied with parameters above.
# -km_title row title for each cluster when ``km`` is set. It must be a text with the format ".*\%i.*" where "\%i" is replaced by the index of the cluster.
# -split a vector or a data frame by which the rows are split. But if ``cluster_rows`` is a clustering object, ``split`` can be a single number
#        indicating rows are to be split according to the split on the tree.
# -gap gap between row-slices if the heatmap is split by rows, should be `grid::unit` object. If it is a vector, the order corresponds
#   to top to bottom in the heatmap
# -combined_name_fun if the heatmap is split by rows, how to make a combined row title for each slice?
#                 The input parameter for this function is a vector which contains level names under each column in ``split``.
# -width the width of the single heatmap, should be a fixed `grid::unit` object. It is used for the layout when the heatmap
#        is appended to a list of heatmaps.
# -show_heatmap_legend whether show heatmap legend?
# -heatmap_legend_param a list contains parameters for the heatmap legend. See `color_mapping_legend,ColorMapping-method` for all available parameters.
# -use_raster whether render the heatmap body as a raster image. It helps to reduce file size when the matrix is huge. Note if ``cell_fun``
#       is set, ``use_raster`` is enforced to be ``FALSE``.
# -raster_device graphic device which is used to generate the raster image
# -raster_quality a value set to larger than 1 will improve the quality of the raster image.
# -raster_device_param a list of further parameters for the selected graphic device
#
# == details
# The initialization function only checks the parameters and fills each slot with proper values.
# Then it will be ready for clustering and layout.
# 
# Following methods can be applied on the `Heatmap-class` object:
#
# - `show,Heatmap-method`: draw a single heatmap with default parameters
# - `draw,Heatmap-method`: draw a single heatmap.
# - `add_heatmap,Heatmap-method`: append heatmaps and row annotations to a list of heatmaps.
#
# The constructor function pretends to be a high-level graphic function because the ``show`` method
# of the `Heatmap-class` object actually plots the graphics.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
Heatmap = function(matrix, col, name, 
    na_col = "grey", 
    color_space = "LAB",
    rect_gp = gpar(col = NA), 
    cell_fun = NULL,
    row_title = character(0), 
    row_title_side = c("left", "right"), 
    row_title_gp = gpar(fontsize = 14), 
    row_title_rot = switch(row_title_side[1], "left" = 90, "right" = 270),
    column_title = character(0), 
    column_title_side = c("top", "bottom"), 
    column_title_gp = gpar(fontsize = 14), 
    column_title_rot = 0,
    cluster_rows = TRUE, 
    clustering_distance_rows = "euclidean",
    clustering_method_rows = "complete", 
    row_dend_side = c("left", "right"),
    row_dend_width = unit(10, "mm"), 
    show_row_dend = TRUE, 
    row_dend_reorder = TRUE,
    row_dend_gp = gpar(), 
    row_hclust_side = row_dend_side,
    row_hclust_width = row_dend_width, 
    show_row_hclust = show_row_dend, 
    row_hclust_reorder = row_dend_reorder,
    row_hclust_gp = row_dend_gp, 
    cluster_columns = TRUE, 
    clustering_distance_columns = "euclidean", 
    clustering_method_columns = "complete",
    column_dend_side = c("top", "bottom"), 
    column_dend_height = unit(10, "mm"), 
    show_column_dend = TRUE, 
    column_dend_gp = gpar(), 
    column_dend_reorder = TRUE,
    column_hclust_side = column_dend_side, 
    column_hclust_height = column_dend_height, 
    show_column_hclust = show_column_dend, 
    column_hclust_gp = column_dend_gp, 
    column_hclust_reorder = column_dend_reorder,
    row_order = NULL, 
    column_order = NULL,
    row_names_side = c("right", "left"), 
    show_row_names = TRUE, 
    row_names_max_width = default_row_names_max_width(), 
    row_names_gp = gpar(fontsize = 12), 
    column_names_side = c("bottom", "top"), 
    show_column_names = TRUE, 
    column_names_max_height = default_column_names_max_height(), 
    column_names_gp = gpar(fontsize = 12),
    top_annotation = new("HeatmapAnnotation"),
    top_annotation_height = top_annotation@size,
    bottom_annotation = new("HeatmapAnnotation"),
    bottom_annotation_height = bottom_annotation@size,
    km = 1, 
    km_title = "cluster%i",
    split = NULL, 
    gap = unit(1, "mm"), 
    combined_name_fun = function(x) paste(x, collapse = "/"),
    width = NULL, 
    show_heatmap_legend = TRUE,
    heatmap_legend_param = list(title = name),
    use_raster = FALSE, 
    raster_device = c("png", "jpeg", "tiff", "CairoPNG", "CairoJPEG", "CairoTIFF"),
    raster_quality = 2,
    raster_device_param = list()) {

    # ---- global defaults for text graphic parameters -------------------------
    # For each of these arguments that the caller did not supply, a non-NULL
    # `ht_global_opt("heatmap_<argument>")` value replaces the default.
    called_args = names(as.list(match.call())[-1])
    e = environment()
    for(opt_name in c("row_names_gp", "column_names_gp", "row_title_gp", "column_title_gp")) {
        opt_name2 = paste0("heatmap_", opt_name)
        if(!opt_name %in% called_args && !is.null(ht_global_opt(opt_name2))) {
            assign(opt_name, ht_global_opt(opt_name2), envir = e)
        }
    }

    # ---- deprecated `*_hclust_*` arguments -----------------------------------
    # A deprecated argument is copied to its `*_dend_*` counterpart unless the
    # counterpart was supplied as well; a deprecation warning is always raised.
    deprecated_args = c("row_hclust_side", "row_hclust_width", "show_row_hclust", "row_hclust_reorder", "row_hclust_gp",
                        "column_hclust_side", "column_hclust_height", "show_column_hclust", "column_hclust_gp", "column_hclust_reorder")
    for(ca in intersect(called_args, deprecated_args)) {
        ca_new = gsub("hclust", "dend", ca)
        if(!ca_new %in% called_args) {
            assign(ca_new, get(ca))
        }
        warning(paste0("'", ca, "' is deprecated in the future, use '", ca_new, "' instead."))
    }

    .Object = new("Heatmap")

    .Object@heatmap_param$width = width
    .Object@heatmap_param$show_heatmap_legend = show_heatmap_legend
    .Object@heatmap_param$use_raster = use_raster
    .Object@heatmap_param$raster_device = match.arg(raster_device)[1]
    .Object@heatmap_param$raster_quality = raster_quality
    .Object@heatmap_param$raster_device_param = raster_device_param

    # ---- coerce the input to a matrix ----------------------------------------
    if(is.data.frame(matrix)) {
        matrix = as.matrix(matrix)
    }
    if(!is.matrix(matrix)) {
        if(is.atomic(matrix)) {
            # a simple vector becomes a one-column matrix, keeping its names as row names
            rn = names(matrix)
            matrix = matrix(matrix, ncol = 1)
            if(!is.null(rn)) rownames(matrix) = rn
            if(!missing(name)) colnames(matrix) = name
        } else {
            stop("If data is not a matrix, it should be a simple vector.")
        }
    }

    if(is.null(width)) {
        .Object@heatmap_param$width = ncol(matrix)
    }

    if(ncol(matrix) == 0) {
        .Object@heatmap_param$show_heatmap_legend = FALSE
        .Object@heatmap_param$width = unit(0, "mm")
    }

    # ---- degenerate dimensions -----------------------------------------------
    # With fewer than two columns (rows) there is nothing to cluster, so
    # clustering is turned off unless a tree object was supplied directly.
    is_empty = ncol(matrix) == 0 || nrow(matrix) == 0
    if((is_empty || ncol(matrix) == 1) && !inherits(cluster_columns, c("dendrogram", "hclust"))) {
        cluster_columns = FALSE
        show_column_dend = FALSE
    }
    if((is_empty || nrow(matrix) == 1) && !inherits(cluster_rows, c("dendrogram", "hclust"))) {
        cluster_rows = FALSE
        show_row_dend = FALSE
    }
    if(is_empty || nrow(matrix) == 1) {
        km = 1
    }

    # Character matrices are clustered only when a distance is given explicitly
    # or a tree object is supplied; dendrogram reordering and k-means are disabled.
    if(is.character(matrix)) {
        if(!("clustering_distance_rows" %in% called_args || inherits(cluster_rows, c("dendrogram", "hclust")))) {
            cluster_rows = FALSE
            show_row_dend = FALSE
        }
        row_dend_reorder = FALSE
        if(!("clustering_distance_columns" %in% called_args || inherits(cluster_columns, c("dendrogram", "hclust")))) {
            cluster_columns = FALSE
            show_column_dend = FALSE
        }
        column_dend_reorder = FALSE
        km = 1
    }

    .Object@matrix = matrix
    .Object@matrix_param$km = km
    .Object@matrix_param$km_title = km_title
    .Object@matrix_param$gap = gap

    # ---- row splitting -------------------------------------------------------
    # With a tree object, or with `cluster_rows = TRUE` and a single number,
    # `split` is kept as given (the number is later used to cut the row
    # dendrogram). Otherwise it must provide one grouping value per matrix row.
    if(!is.null(split) && !inherits(cluster_rows, c("dendrogram", "hclust"))) {
        if(!(identical(cluster_rows, TRUE) && is.numeric(split) && length(split) == 1)) {
            if(!is.data.frame(split)) split = data.frame(split)
            if(nrow(split) != nrow(matrix)) {
                stop("Length or number of rows of `split` should be same as rows in `matrix`.")
            }
        }
    }
    .Object@matrix_param$split = split
    .Object@matrix_param$gp = check_gp(rect_gp)
    .Object@matrix_param$cell_fun = cell_fun

    if(missing(name)) {
        name = paste0("matrix_", get_heatmap_index() + 1)
        increase_heatmap_index()
    }
    .Object@name = name

    # ---- global defaults for legend parameters -------------------------------
    # This must run after `name` has its final value: the default of
    # `heatmap_legend_param` is `list(title = name)`, and forcing it while
    # `name` is still a missing argument raises an error. Options already set
    # by the caller in `heatmap_legend_param` are never overridden.
    legend_opts = c("title_gp", "title_position", "labels_gp", "grid_width", "grid_height", "grid_border")
    for(opt_name in setdiff(legend_opts, names(heatmap_legend_param))) {
        opt_name2 = paste0("heatmap_legend_", opt_name)
        if(!is.null(ht_global_opt(opt_name2))) {
            heatmap_legend_param[[opt_name]] = ht_global_opt(opt_name2)
        }
    }

    if(ncol(matrix) == 1 && is.null(colnames(matrix))) {
        colnames(matrix) = name
        .Object@matrix = matrix
    }

    # ---- color mapping for the main matrix -----------------------------------
    if(ncol(matrix) > 0 && nrow(matrix) > 0) {
        if(missing(col)) {
            col = default_col(matrix, main_matrix = TRUE)
        }
        if(is.function(col)) {
            .Object@matrix_color_mapping = ColorMapping(col_fun = col, name = name, na_col = na_col)
        } else if(is.null(names(col))) {
            # NA cells are colored by `na_col`, so NA is not counted as a level.
            # sort() drops NA, which keeps `matrix_levels` and names(col) aligned.
            matrix_levels = sort(unique(as.vector(matrix)))
            if(length(col) == length(matrix_levels)) {
                names(col) = matrix_levels
                .Object@matrix_color_mapping = ColorMapping(colors = col, name = name, na_col = na_col)
            } else if(is.numeric(matrix)) {
                # interpolate the given colors over the range of the data
                col = colorRamp2(seq(min(matrix, na.rm = TRUE), max(matrix, na.rm = TRUE), length.out = length(col)),
                                 col, space = color_space)
                .Object@matrix_color_mapping = ColorMapping(col_fun = col, name = name, na_col = na_col)
            } else {
                stop("`col` should have names to map to values in `mat`.")
            }
        } else {
            col = col[intersect(c(names(col), "_NA_"), as.character(matrix))]
            .Object@matrix_color_mapping = ColorMapping(colors = col, name = name, na_col = na_col)
        }
        .Object@matrix_color_mapping_param = heatmap_legend_param
    }

    # ---- titles --------------------------------------------------------------
    .Object@row_title = normalize_heatmap_title(row_title)
    .Object@row_title_rot = row_title_rot %% 360
    .Object@row_title_param$side = match.arg(row_title_side)[1]
    .Object@row_title_param$gp = check_gp(row_title_gp)  # if the number of settings is same as number of row-splits, gp will be adjusted by `make_row_dend`
    .Object@row_title_param$combined_name_fun = combined_name_fun
    .Object@row_title_just = get_text_just(rot = .Object@row_title_rot, side = .Object@row_title_param$side)

    .Object@column_title = normalize_heatmap_title(column_title)
    .Object@column_title_rot = column_title_rot %% 360
    .Object@column_title_param$side = match.arg(column_title_side)[1]
    .Object@column_title_param$gp = check_gp(column_title_gp)
    .Object@column_title_just = get_text_just(rot = .Object@column_title_rot, side = .Object@column_title_param$side)

    # ---- row and column names ------------------------------------------------
    if(is.null(rownames(matrix))) {
        show_row_names = FALSE
    }
    .Object@row_names_param$side = match.arg(row_names_side)[1]
    .Object@row_names_param$show = show_row_names
    .Object@row_names_param$gp = check_gp(row_names_gp)
    # referenced lazily by the default of `row_names_max_width`, so it must be
    # defined before that argument is first used below
    default_row_names_max_width = function() {
        min(unit.c(unit(6, "cm")), max_text_width(rownames(matrix), gp = .Object@row_names_param$gp))
    }
    .Object@row_names_param$max_width = row_names_max_width + unit(2, "mm")

    if(is.null(colnames(matrix))) {
        show_column_names = FALSE
    }
    .Object@column_names_param$side = match.arg(column_names_side)[1]
    .Object@column_names_param$show = show_column_names
    .Object@column_names_param$gp = check_gp(column_names_gp)
    # referenced lazily by the default of `column_names_max_height`
    default_column_names_max_height = function() {
        min(unit.c(unit(6, "cm")), max_text_width(colnames(matrix), gp = .Object@column_names_param$gp))
    }
    .Object@column_names_param$max_height = column_names_max_height + unit(2, "mm")

    # ---- row clustering parameters -------------------------------------------
    if(inherits(cluster_rows, "dendrogram") || inherits(cluster_rows, "hclust")) {
        .Object@row_dend_param$obj = cluster_rows
        .Object@row_dend_param$cluster = TRUE
    } else if(inherits(cluster_rows, "function")) {
        .Object@row_dend_param$fun = cluster_rows
        .Object@row_dend_param$cluster = TRUE
    } else {
        .Object@row_dend_param$cluster = cluster_rows
        if(!cluster_rows) {
            row_dend_width = unit(0, "mm")
            show_row_dend = FALSE
        }
    }
    if(!show_row_dend) {
        row_dend_width = unit(0, "mm")
    }
    .Object@row_dend_list = list()
    .Object@row_dend_param$distance = clustering_distance_rows
    .Object@row_dend_param$method = clustering_method_rows
    .Object@row_dend_param$side = match.arg(row_dend_side)[1]
    .Object@row_dend_param$width = row_dend_width + unit(1, "mm")  # append the gap
    .Object@row_dend_param$show = show_row_dend
    .Object@row_dend_param$gp = check_gp(row_dend_gp)
    .Object@row_dend_param$reorder = row_dend_reorder
    .Object@row_order_list = list() # default order
    if(is.null(row_order)) {
        .Object@row_order = seq_len(nrow(matrix))
    } else {
        if(is.character(row_order)) {
            # translate row names into row indices
            row_order = structure(seq_len(nrow(matrix)), names = rownames(matrix))[row_order]
        }
        .Object@row_order = row_order
    }

    # ---- column clustering parameters ----------------------------------------
    if(inherits(cluster_columns, "dendrogram") || inherits(cluster_columns, "hclust")) {
        .Object@column_dend_param$obj = cluster_columns
        .Object@column_dend_param$cluster = TRUE
    } else if(inherits(cluster_columns, "function")) {
        .Object@column_dend_param$fun = cluster_columns
        .Object@column_dend_param$cluster = TRUE
    } else {
        .Object@column_dend_param$cluster = cluster_columns
        if(!cluster_columns) {
            column_dend_height = unit(0, "mm")
            show_column_dend = FALSE
        }
    }
    if(!show_column_dend) {
        column_dend_height = unit(0, "mm")
    }
    .Object@column_dend = NULL
    .Object@column_dend_param$distance = clustering_distance_columns
    .Object@column_dend_param$method = clustering_method_columns
    .Object@column_dend_param$side = match.arg(column_dend_side)[1]
    .Object@column_dend_param$height = column_dend_height + unit(1, "mm")  # append the gap
    .Object@column_dend_param$show = show_column_dend
    .Object@column_dend_param$gp = check_gp(column_dend_gp)
    .Object@column_dend_param$reorder = column_dend_reorder
    if(is.null(column_order)) {
        .Object@column_order = seq_len(ncol(matrix))
    } else {
        if(is.character(column_order)) {
            # translate column names into column indices
            column_order = structure(seq_len(ncol(matrix)), names = colnames(matrix))[column_order]
        }
        .Object@column_order = column_order
    }

    # ---- column annotations --------------------------------------------------
    .Object@top_annotation = top_annotation # a `HeatmapAnnotation` object
    if(is.null(top_annotation)) {
        .Object@top_annotation_param$height = unit(0, "mm")
    } else {
        .Object@top_annotation_param$height = top_annotation_height + unit(1, "mm")  # append the gap
        if(length(top_annotation@anno_list) > 0 && top_annotation@which != "column") {
            stop("`which` in `top_annotation` should only be `column`.")
        }
    }

    .Object@bottom_annotation = bottom_annotation # a `HeatmapAnnotation` object
    if(is.null(bottom_annotation)) {
        .Object@bottom_annotation_param$height = unit(0, "mm")
    } else {
        .Object@bottom_annotation_param$height = bottom_annotation_height + unit(1, "mm")  # append the gap
        if(length(bottom_annotation@anno_list) > 0 && bottom_annotation@which != "column") {
            stop("`which` in `bottom_annotation` should only be `column`.")
        }
    }

    .Object@layout = list(
        layout_column_title_top_height = unit(0, "mm"),
        layout_column_dend_top_height = unit(0, "mm"),
        layout_column_anno_top_height = unit(0, "mm"),
        layout_column_names_top_height = unit(0, "mm"),
        layout_column_title_bottom_height = unit(0, "mm"),
        layout_column_dend_bottom_height = unit(0, "mm"),
        layout_column_anno_bottom_height = unit(0, "mm"),
        layout_column_names_bottom_height = unit(0, "mm"),

        layout_row_title_left_width = unit(0, "mm"),
        layout_row_dend_left_width = unit(0, "mm"),
        layout_row_names_left_width = unit(0, "mm"),
        layout_row_dend_right_width = unit(0, "mm"),
        layout_row_names_right_width = unit(0, "mm"),
        layout_row_title_right_width = unit(0, "mm"),

        layout_heatmap_width = width, # for the layout of heatmap list

        layout_index = matrix(nrow = 0, ncol = 2),
        graphic_fun_list = list()
    )

    return(.Object)

}

# Normalize a row or column title for the `Heatmap` constructor (internal).
#
# A zero-length title, a single NA or a single empty string becomes
# character(0). Expressions and calls are returned unchanged. Vectors with
# more than one element (e.g. one title per row slice) are also returned
# unchanged; testing them with a scalar `if()` would be an error in R >= 4.2.
normalize_heatmap_title = function(title) {
    if(length(title) == 0) {
        return(character(0))
    }
    if(inherits(title, c("expression", "call"))) {
        return(title)
    }
    if(length(title) == 1 && (is.na(title) || title == "")) {
        return(character(0))
    }
    title
}

# == title
# Make cluster on columns
#
# == param
# -object a `Heatmap-class` object.
#
# == details
# The function will fill or adjust ``column_dend`` and ``column_order`` slots.
#
# This function is only for internal use.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "make_column_cluster",
    signature = "Heatmap",
    definition = function(object) {

    # hclust implementation selected by the global `fast_hclust` option
    hclust = if(ht_global_opt("fast_hclust")) fastcluster::hclust else stats::hclust

    mat = object@matrix
    param = object@column_dend_param

    if(!param$cluster) {
        # no clustering: keep the pre-defined column order
        new_order = object@column_order
    } else {
        # tree source, by priority: supplied tree > supplied function > hclust on columns
        if(!is.null(param$obj)) {
            dend = param$obj
        } else if(!is.null(param$fun)) {
            dend = param$fun(t(mat))
        } else {
            dend = hclust(get_dist(t(mat), param$distance), method = param$method)
        }
        # the pre-defined column order is ignored when clustering
        new_order = get_dend_order(dend)

        if(inherits(dend, "hclust")) {
            dend = as.dendrogram(dend)
        }

        # NULL means "decide by matrix type": only numeric matrices are reordered
        weight = param$reorder
        if(is.null(weight)) {
            weight = is.numeric(mat)
        }
        skip_reorder = identical(weight, NA) || identical(weight, FALSE)
        if(identical(weight, TRUE)) {
            # TRUE means: use column means as the reordering weights
            weight = colMeans(mat, na.rm = TRUE)
        }

        if(!skip_reorder) {
            if(length(weight) != ncol(mat)) {
                stop("weight of reordering should have same length as number of columns.\n")
            }
            dend = reorder(dend, weight)
            new_order = order.dendrogram(dend)
        }

        object@column_dend = dend
    }

    object@column_order = new_order

    if(length(new_order) != ncol(mat)) {
        stop("Number of columns in the matrix are not the same as the length of\nthe cluster or the column order.")
    }

    object
})


# == title
# Make cluster on rows
#
# == param
# -object a `Heatmap-class` object.
#
# == details
# The function will fill or adjust ``row_dend_list``, ``row_order_list``, ``row_title`` and ``matrix_param`` slots.
#
# If row clustering is turned off, the pre-defined row order (``order``) is used and no clustering will be applied.
#
# If a clustering object is available (either provided by the user, or created here by clustering the whole
# matrix when ``split`` is a single number), the dendrogram is cut into ``split`` row slices. Otherwise rows
# are split by k-means clusters (when ``km > 1``) and/or by the levels of ``split``, and every slice is
# clustered separately, either by the user-defined clustering function or by hierarchical clustering.
#
# Dendrograms are reordered by ``reorder`` weights; when ``reorder`` is ``TRUE`` (the default for numeric
# matrices) the weights are the negative row means.
#
# This function is only for internal use.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "make_row_cluster",
    signature = "Heatmap",
    definition = function(object) {

    if(ht_global_opt("fast_hclust")) {
        hclust = fastcluster::hclust
    } else {
        hclust = stats::hclust
    }

    mat = object@matrix
    distance = object@row_dend_param$distance
    method = object@row_dend_param$method
    order = object@row_order  # pre-defined row order
    km = object@matrix_param$km
    km_title = object@matrix_param$km_title
    split = object@matrix_param$split
    reorder = object@row_dend_param$reorder

    if(object@row_dend_param$cluster) {

        # a single number for `split` means: cluster the whole matrix, then cut the dendrogram
        if(is.numeric(split) && length(split) == 1) {
            if(is.null(object@row_dend_param$obj)) {
                object@row_dend_param$obj = hclust(get_dist(mat, distance), method = method)
            }
        }

        ## case 1: a clustering object exists, row slices come from cutting its dendrogram
        if(!is.null(object@row_dend_param$obj)) {
            if(km > 1) {
                stop("You can not make k-means clustering since you have already specified a clustering object.")
            }

            if(inherits(object@row_dend_param$obj, "hclust")) {
                object@row_dend_param$obj = as.dendrogram(object@row_dend_param$obj)
            }

            if(is.null(split)) {
                object@row_dend_list = list(object@row_dend_param$obj)
                object@row_order_list = list(get_dend_order(object@row_dend_param$obj))
            } else {
                if(length(split) > 1 || !is.numeric(split)) {
                    stop("Since you specified a clustering object, you can only split rows by providing a number (number of row slices.")
                }
                if(split < 2) {
                    stop("Here `split` should be equal or larger than 2.")
                }
                
                object@row_dend_list = cut_dendrogram(object@row_dend_param$obj, split)
                # split the leaf order of the full dendrogram into consecutive chunks, one per sub-dendrogram
                sth = tapply(order.dendrogram(object@row_dend_param$obj), 
                    rep(seq_along(object@row_dend_list), times = sapply(object@row_dend_list, nobs)), 
                    function(x) x)
                attributes(sth) = NULL
                object@row_order_list = sth
            }

            if(identical(reorder, NULL)) {
                if(is.numeric(mat)) {
                    reorder = TRUE
                } else {
                    reorder = FALSE
                }
            }

            do_reorder = TRUE
            if(identical(reorder, NA) || identical(reorder, FALSE)) {
                do_reorder = FALSE
            }
            if(identical(reorder, TRUE)) {
                do_reorder = TRUE
                reorder = -rowMeans(mat, na.rm = TRUE)
            }

            if(do_reorder) {

                if(length(reorder) != nrow(mat)) {
                    stop("weight of reordering should have same length as number of rows.\n")
                }
                row_order_list = object@row_order_list
                row_dend_list = object@row_dend_list
                o_row_order_list = row_order_list
                for(i in seq_along(row_dend_list)) {
                    if(length(row_order_list[[i]]) > 1) {
                        sub_ind = which(seq_len(nrow(mat)) %in% o_row_order_list[[i]])
                        object@row_dend_list[[i]] = reorder(object@row_dend_list[[i]], reorder[sub_ind])
                        object@row_order_list[[i]] = order.dendrogram(object@row_dend_list[[i]])
                    }
                }
            }
            return(object)
        }

        row_order = seq_len(nrow(mat))
    } else {
        row_order = order
    }

    ## case 2: k-means clustering adds an extra split column
    if(km > 1 && is.numeric(mat)) {
        km.fit = kmeans(mat, centers = km)
        cluster = km.fit$cluster
        # use the cluster labels in sorted order so that column j of `meanmat`
        # corresponds to `cluster_id[j]`; `hc$order` below refers to these positions
        cluster_id = sort(unique(cluster))
        meanmat = lapply(cluster_id, function(i) {
            colMeans(mat[cluster == i, , drop = FALSE])
        })
        meanmat = as.matrix(as.data.frame(meanmat))
        # order the k-means clusters by clustering their mean profiles
        hc = hclust(dist(t(meanmat)))
        weight = colMeans(meanmat)
        hc = as.hclust(reorder(as.dendrogram(hc), -weight))
        cluster2 = numeric(length(cluster))
        for(i in seq_along(hc$order)) {
            # map the position in `meanmat` back to the original k-means label
            cluster2[cluster == cluster_id[hc$order[i]]] = i
        }
        cluster2 = factor(paste0("cluster", cluster2), levels = paste0("cluster", seq_along(hc$order)))
        cluster2 = factor(sprintf(km_title, cluster2), levels = sprintf(km_title, seq_along(hc$order)))

        if(is.null(split)) {
            split = data.frame(cluster2)
        } else if(is.matrix(split)) {
            split = as.data.frame(split)
            split = cbind(cluster2, split)
        } else if(is.null(ncol(split))) {
            split = data.frame(cluster2, split)
        } else {
            split = cbind(cluster2, split)
        }
            
    }

    # split the original order into a list according to split
    row_order_list = list()
    if(is.null(split)) {
        row_order_list[[1]] = row_order
    } else {
        if(is.null(ncol(split))) split = data.frame(split)
        if(is.matrix(split)) split = as.data.frame(split)

        # make every split column a factor whose levels only contain observed values
        for(i in seq_len(ncol(split))) {
            if(is.numeric(split[[i]])) {
                split[[i]] = factor(as.character(split[[i]]), levels = as.character(sort(unique(split[[i]]))))
            } else if(!is.factor(split[[i]])) {
                split[[i]] = factor(split[[i]])
            } else {
                # re-factor
                split[[i]] = factor(split[[i]], levels = intersect(levels(split[[i]]), unique(split[[i]])))
            }
        }

        split_name = NULL
        combined_name_fun = object@row_title_param$combined_name_fun
        if(!is.null(combined_name_fun)) {
            split_name = apply(as.matrix(split), 1, combined_name_fun)
        } else {
            split_name = apply(as.matrix(split), 1, paste, collapse = "\n")
        }

        # slices follow the lexicographic order of the split columns
        row_order2 = do.call("order", split)
        row_level = unique(split_name[row_order2])
        for(k in seq_along(row_level)) {
            l = split_name == row_level[k]
            row_order_list[[k]] = intersect(row_order, which(l))
        }

        object@row_order_list = row_order_list

        if(!is.null(combined_name_fun)) {
            object@row_title = row_level
        }
    }
    o_row_order_list = row_order_list
    # make dend in each slice
    if(object@row_dend_param$cluster) {
        row_dend_list = rep(list(NULL), length(row_order_list))
        for(i in seq_along(row_order_list)) {
            submat = mat[ row_order_list[[i]], , drop = FALSE]
            if(nrow(submat) > 1) {
                if(!is.null(object@row_dend_param$fun)) {
                    # cluster only the rows of this slice so that the dendrogram order
                    # indexes into row_order_list[[i]]
                    row_dend_list[[i]] = object@row_dend_param$fun(submat)
                    row_order_list[[i]] = row_order_list[[i]][ get_dend_order(row_dend_list[[i]]) ]
                } else {
                    row_dend_list[[i]] = hclust(get_dist(submat, distance), method = method)
                    row_order_list[[i]] = row_order_list[[i]][ get_dend_order(row_dend_list[[i]]) ]
                }
            } else {
                # a single-row slice has no dendrogram
                row_order_list[[i]] = row_order_list[[i]][1]
            }
        }
        object@row_dend_list = row_dend_list

        for(i in seq_along(object@row_dend_list)) {
            if(inherits(object@row_dend_list[[i]], "hclust")) {
                object@row_dend_list[[i]] = as.dendrogram(object@row_dend_list[[i]])
            }
        }

        if(identical(reorder, NULL)) {
            if(is.numeric(mat)) {
                reorder = TRUE
            } else {
                reorder = FALSE
            }
        }

        do_reorder = TRUE
        if(identical(reorder, NA) || identical(reorder, FALSE)) {
            do_reorder = FALSE
        }
        if(identical(reorder, TRUE)) {
            do_reorder = TRUE
            reorder = -rowMeans(mat, na.rm = TRUE)
        }

        if(do_reorder) {

            if(length(reorder) != nrow(mat)) {
                stop("weight of reordering should have same length as number of rows.\n")
            }
            for(i in seq_along(row_dend_list)) {
                if(length(row_order_list[[i]]) > 1) {
                    sub_ind = which(seq_len(nrow(mat)) %in% o_row_order_list[[i]])
                    object@row_dend_list[[i]] = reorder(object@row_dend_list[[i]], reorder[sub_ind])
                    row_order_list[[i]] = sub_ind[ order.dendrogram(object@row_dend_list[[i]]) ]
                }
            }
        }
    }

    object@row_order_list = row_order_list
    object@matrix_param$split = split

    if(nrow(mat) != length(unlist(row_order_list))) {
        stop("Number of rows in the matrix are not the same as the length of\nthe cluster or the row orders.")
    }

    # adjust row_names_param$gp if the length of some elements is the same as row slices:
    # a per-slice value is expanded to a per-row value
    for(i in seq_along(object@row_names_param$gp)) {
        if(length(object@row_names_param$gp[[i]]) == length(object@row_order_list)) {
            gp_temp = NULL
            for(j in seq_along(object@row_order_list)) {
                gp_temp[ object@row_order_list[[j]] ] = object@row_names_param$gp[[i]][j]
            }
            object@row_names_param$gp[[i]] = gp_temp
        }
    }
    return(object)

})

# == title
# Make the layout of a single heatmap
#
# == param
# -object a `Heatmap-class` object.
# 
# == details
# The layout of the single heatmap will be established by setting the size of each heatmap component.
# Also functions that make graphics for heatmap components will be recorded.
#
# Whether row clustering or column clustering is applied affects the layout, so clustering should be applied 
# first before making the layout.
#
# Every component that is shown appends one row (its row/column position in the 9 x 7 layout) to
# ``layout$layout_index`` and one drawing function to ``layout$graphic_fun_list`` in the same branch,
# so the i-th row of the index corresponds to the i-th drawing function.
#
# This function is only for internal use.
#
# == value
# A `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "make_layout",
    signature = "Heatmap",
    definition = function(object) {

    # for components which are placed by rows, they will be split into parts
    # and slice_y controls the y-coordinates of each part

    # position of each row-slice
    gap = object@matrix_param$gap
    n_slice = length(object@row_order_list)
    # normalize `gap` to one value per slice; if n_slice - 1 gaps are given,
    # a zero gap is appended for the last slice
    if(length(gap) == 1) {
        gap = rep(gap, n_slice)
    } else if(length(gap) == n_slice - 1) {
        gap = unit.c(gap, unit(0, "mm"))
    } else if(length(gap) != n_slice) {
        stop("Length of `gap` should be 1 or number of row slices.")
    }

    # number of rows in each slice; slice heights are proportional to it
    snr = sapply(object@row_order_list, length)
    # NOTE(review): when there are no rows, slice_y/slice_height are never defined,
    # but the row title/dend/name closures below still refer to them.
    if(sum(snr)) {
        if(n_slice == 1) {
            slice_height = unit(1, "npc")*(snr/sum(snr))
            } else {
                # the gaps between slices are subtracted before distributing the height
                slice_height = (unit(1, "npc") - sum(gap[seq_len(n_slice-1)]))*(snr/sum(snr))
            }  
        # slice_y is the top y-coordinate of each slice, starting at the top (1npc)
        # and moving down by the heights and gaps of all previous slices
        for(i in seq_len(n_slice)) {
            if(i == 1) {
                slice_y = unit(1, "npc")
            } else {
                slice_y = unit.c(slice_y, unit(1, "npc") - sum(slice_height[seq_len(i-1)]) - sum(gap[seq_len(i-1)]))
            }
        }

        ###########################################
        ## heatmap body
        # the body sits in row 5, column 4 of the layout; the closure captures
        # n_slice, slice_y and slice_height from this function's environment
        object@layout$layout_index = rbind(c(5, 4))
        object@layout$graphic_fun_list = list(function(object) {
            for(i in seq_len(n_slice)) {
                draw_heatmap_body(object, k = i, y = slice_y[i], height = slice_height[i], just = c("center", "top"))
            }
        })
    }

    # space added above and below (or left and right of) titles
    title_padding = unit(2.5, "mm")
    ############################################
    ## title on top or bottom
    column_title = object@column_title
    column_title_side = object@column_title_param$side
    column_title_gp = object@column_title_param$gp
    if(length(column_title) > 0) {
        if(column_title_side == "top") {
            # a title rotated by 90/270 degrees needs its text width as height
            if(object@column_title_rot %in% c(0, 180)) {
                object@layout$layout_column_title_top_height = grobHeight(textGrob(column_title, gp = column_title_gp)) + title_padding*2
            } else {
                object@layout$layout_column_title_top_height = grobWidth(textGrob(column_title, gp = column_title_gp)) + title_padding*2
            }
            object@layout$layout_index = rbind(object@layout$layout_index, c(1, 4))
        } else {
            if(object@column_title_rot %in% c(0, 180)) {
                object@layout$layout_column_title_bottom_height = grobHeight(textGrob(column_title, gp = column_title_gp)) + title_padding*2
            } else {
                object@layout$layout_column_title_bottom_height = grobWidth(textGrob(column_title, gp = column_title_gp)) + title_padding*2
            }
            object@layout$layout_index = rbind(object@layout$layout_index, c(9, 4))
        }
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) draw_title(object, which = "column"))
    }

    ############################################
    ## title on left or right
    row_title = object@row_title
    row_title_side = object@row_title_param$side
    row_title_gp = object@row_title_param$gp
    if(length(row_title) > 0) {
        if(row_title_side == "left") {
            # horizontal titles need their text width, rotated ones their text height
            if(object@row_title_rot %in% c(0, 180)) {
                object@layout$layout_row_title_left_width = max_text_width(row_title, gp = row_title_gp) + title_padding*2
            } else {
                object@layout$layout_row_title_left_width = max_text_height(row_title, gp = row_title_gp) + title_padding*2
            }
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 1))
        } else {
            if(object@row_title_rot %in% c(0, 180)) {
                object@layout$layout_row_title_right_width = max_text_width(row_title, gp = row_title_gp) + title_padding*2
            } else {
                object@layout$layout_row_title_right_width = max_text_height(row_title, gp = row_title_gp) + title_padding*2
            }
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 7))
        }
        # one title per row slice, aligned with the slices of the body
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) {
            for(i in seq_len(n_slice)) {
                draw_title(object, k = i, which = "row", y = slice_y[i], height = slice_height[i], just = c("center", "top"))
            }
        })
    }

    ##########################################
    ## dend on left or right
    show_row_dend = object@row_dend_param$show
    row_dend_side = object@row_dend_param$side
    row_dend_width = object@row_dend_param$width
    if(show_row_dend) {
        if(row_dend_side == "left") {
            object@layout$layout_row_dend_left_width = row_dend_width
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 2))
        } else {
            object@layout$layout_row_dend_right_width = row_dend_width
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 6))
        }
        #max_dend_height = max(sapply(object@row_dend_list, function(hc) attr(as.dendrogram(hc), "height")))
        # one dendrogram per row slice
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) {
            for(i in seq_len(n_slice)) {
                draw_dend(object, k = i, which = "row", y = slice_y[i], height = slice_height[i], just = c("center", "top"))
            }
        })
    }

    ##########################################
    ## dend on top or bottom
    show_column_dend = object@column_dend_param$show
    column_dend_side = object@column_dend_param$side
    column_dend_height = object@column_dend_param$height
    if(show_column_dend) {
        if(column_dend_side == "top") {
            object@layout$layout_column_dend_top_height = column_dend_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(2, 4))
        } else {
            object@layout$layout_column_dend_bottom_height = column_dend_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(8, 4))
        }
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) draw_dend(object, which = "column"))
    }
    

    # space between the heatmap body and the row/column names
    dimname_padding = unit(2, "mm")

    #######################################
    ## row_names on left or right
    row_names_side = object@row_names_param$side
    show_row_names = object@row_names_param$show
    row_names = rownames(object@matrix)
    row_names_gp = object@row_names_param$gp;
    if(show_row_names) {
        # widest row name (each name may have its own gp), capped by max_width
        row_names_width = max(do.call("unit.c", lapply(seq_along(row_names), function(x) {
            cgp = subset_gp(row_names_gp, x)
            grobWidth(textGrob(row_names[x], gp = cgp))
        }))) + dimname_padding
        row_names_width = min(row_names_width, object@row_names_param$max_width)
        if(row_names_side == "left") {
            object@layout$layout_row_names_left_width = row_names_width
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 3))
        } else {
            object@layout$layout_row_names_right_width = row_names_width
            object@layout$layout_index = rbind(object@layout$layout_index, c(5, 5))
        }
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) {
            for(i in seq_len(n_slice)) {
                draw_dimnames(object, k = i, which = "row", x = unit(0, "npc"), y = slice_y[i], height = slice_height[i], just = c("left", "top"), dimname_padding = dimname_padding)
            }
        })
    }

    #########################################
    ## column_names on top or bottom
    column_names_side = object@column_names_param$side
    show_column_names = object@column_names_param$show
    column_names = colnames(object@matrix)
    column_names_gp = object@column_names_param$gp
    if(show_column_names) {
        # column names are drawn rotated by 90 degrees, so the region height is
        # the widest text width, capped by max_height
        column_names_height = max(do.call("unit.c", lapply(seq_along(column_names), function(x) {
            cgp = subset_gp(column_names_gp, x)
            grobWidth(textGrob(column_names[x], gp = cgp))
        }))) + dimname_padding
        column_names_height = min(column_names_height, object@column_names_param$max_height)
        if(column_names_side == "top") {
            object@layout$layout_column_names_top_height = column_names_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(4, 4))
        } else {
            object@layout$layout_column_names_bottom_height = column_names_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(6, 4))
        }
        object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) draw_dimnames(object, which = "column", y = unit(1, "npc"), just = c("center", "top"), dimname_padding = dimname_padding))
    }
    
    ##########################################
    ## annotation on top
    annotation = object@top_annotation
    annotation_height = object@top_annotation_param$height
    # only reserve space if the annotation actually contains annotations
    if(!is.null(annotation)) {
        if(length(annotation@anno_list) > 0) {
            object@layout$layout_column_anno_top_height = annotation_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(3, 4))
            
            object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) draw_annotation(object, which = "top"))
        }
    }

    ##########################################
    ## annotation on bottom
    annotation = object@bottom_annotation
    annotation_height = object@bottom_annotation_param$height
    if(!is.null(annotation)) {
        if(length(annotation@anno_list) > 0) {
            object@layout$layout_column_anno_bottom_height = annotation_height
            object@layout$layout_index = rbind(object@layout$layout_index, c(7, 4))
            object@layout$graphic_fun_list = c(object@layout$graphic_fun_list, function(object) draw_annotation(object, which = "bottom"))
        }
    }

    return(object)
})

# == title
# Draw the single heatmap with default parameters
#
# == param
# -object a `Heatmap-class` object.
#
# == details
# This is what is called when a `Heatmap-class` object is printed. It simply delegates to
# `draw,Heatmap-method` with default parameters; to customize the heatmap, call
# `draw,Heatmap-method` directly with the desired parameters.
#
# == value
# This function returns no value.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "show",
    signature = "Heatmap",
    definition = function(object) draw(object))

# == title
# Add heatmaps or row annotations as a heatmap list
#
# == param
# -object a `Heatmap-class` object.
# -x a `Heatmap-class` object, a `HeatmapAnnotation-class` object or a `HeatmapList-class` object.
#
# == details
# An empty `HeatmapList-class` object is created, ``object`` is appended to it first and ``x`` second.
# There is a shortcut function ``+.AdditiveUnit``.
#
# == value
# A `HeatmapList-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "add_heatmap",
    signature = "Heatmap",
    definition = function(object, x) {

    # append `object` then `x` to a fresh, empty heatmap list
    add_heatmap(add_heatmap(new("HeatmapList"), object), x)
})

# == title
# Draw the heatmap body
#
# == param
# -object a `Heatmap-class` object.
# -k a matrix may be split by rows, the value identifies which row-slice.
# -... pass to `grid::viewport`, basically for defining the position of the viewport.
#
# == details
# The matrix can be split into several parts by rows if ``km`` or ``split`` is 
# specified when initializing the `Heatmap` object. If the matrix is split, 
# there will be gaps between rows to identify different row-slice.
#
# A viewport is created which contains subset rows of the heatmap.
#
# When raster mode is on (and no ``cell_fun`` is set), the cells are rendered into a
# temporary image file by a separate R process and the image is drawn back with
# `grid::grid.raster`. All temporary files are removed when the function exits.
#
# When ``cell_fun`` is a function, it is called once per cell with the column index and
# row index in the original matrix, the cell center (x, y), the cell width and height,
# and the cell fill colour.
#
# This function is only for internal use.
#
# == value
# This function returns no value.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw_heatmap_body",
    signature = "Heatmap",
    definition = function(object, k = 1, ...) {

    if(ncol(object@matrix) == 0) {
        return(invisible(NULL))
    }

    row_order = object@row_order_list[[k]]
    column_order = object@column_order

    gp = object@matrix_param$gp
    use_raster = object@heatmap_param$use_raster
    raster_device = object@heatmap_param$raster_device
    raster_quality = object@heatmap_param$raster_quality
    raster_device_param = object@heatmap_param$raster_device_param
    if(length(raster_device_param) == 0) raster_device_param = list()

    pushViewport(viewport(name = paste(object@name, "heatmap_body", k, sep = "_"), ...))

    mat = object@matrix[row_order, column_order, drop = FALSE]
    col_matrix = map_to_colors(object@matrix_color_mapping, mat)

    nc = ncol(mat)
    nr = nrow(mat)
    # cell centers in npc; the first row of the slice is drawn at the top
    x = (seq_len(nc) - 0.5) / nc
    y = (rev(seq_len(nr)) - 0.5) / nr
    expand_index = expand.grid(seq_len(nr), seq_len(nc))
    
    cell_fun = object@matrix_param$cell_fun
    if(!is.null(cell_fun)) {
        # cell_fun draws into the live viewport, which is not possible for a rasterized body
        use_raster = FALSE
    }
        
    if(use_raster) {
        # write the image into a temporary file and read it back
        device_info = switch(raster_device,
            png = c("grDevices", "png", "readPNG"),
            jpeg = c("grDevices", "jpeg", "readJPEG"),
            tiff = c("grDevices", "tiff", "readTIFF"),
            CairoPNG = c("Cairo", "png", "readPNG"),
            CairoJPEG = c("Cairo", "jpeg", "readJPEG"),
            CairoTIFF = c("Cairo", "tiff", "readTIFF")
        )
        if(is.null(device_info)) {
            stop(paste0("Unsupported raster device: ", raster_device, "."))
        }
        if(!requireNamespace(device_info[1])) {
            stop(paste0("Need ", device_info[1], " package to output image."))
        }
        if(!requireNamespace(device_info[2])) {
            stop(paste0("Need ", device_info[2], " package to read image."))
        }
        # can we get the size of the heatmap body?
        heatmap_width = convertWidth(unit(1, "npc"), "bigpts", valueOnly = TRUE)
        heatmap_height = convertHeight(unit(1, "npc"), "bigpts", valueOnly = TRUE)
        if(heatmap_width <= 0 || heatmap_height <= 0) {
            stop("The width or height of the raster image is zero, maybe you forget to turn off the previous graphic device or it was corrupted. Run `dev.off()` to close it.")
        }
        
        temp_dir = tempdir()
        file_prefix = paste0(".heatmap_body_", object@name, "_", k, "_")
        temp_image = tempfile(pattern = file_prefix, tmpdir = temp_dir, fileext = paste0(".", device_info[2]))
        temp_R_data = tempfile(pattern = file_prefix, tmpdir = temp_dir, fileext = ".RData")
        temp_R_file = tempfile(pattern = file_prefix, tmpdir = temp_dir, fileext = ".R")
        # remove all temporary files however this function exits (also on error)
        on.exit(unlink(c(temp_image, temp_R_data, temp_R_file)), add = TRUE)
        device_fun = getFromNamespace(raster_device, ns = device_info[1])

        ############################################
        ## make the heatmap body in a another process
        if(is_windows()) {
            # use forward slashes so the paths survive being embedded in the R script
            temp_image = gsub("\\\\", "/", temp_image)
            temp_R_data = gsub("\\\\", "/", temp_R_data)
            temp_R_file = gsub("\\\\", "/", temp_R_file)
        }
        save(device_fun, device_info, temp_image, heatmap_width, raster_quality, heatmap_height, raster_device_param,
            gp, x, expand_index, nc, nr, col_matrix, row_order, column_order, y,
            file = temp_R_data)
        R_cmd = qq("
        library(@{device_info[1]})
        library(grid)
        load('@{temp_R_data}')
        do.call('device_fun', c(list(filename = temp_image, width = max(c(heatmap_width*raster_quality, 1)), height = max(c(heatmap_height*raster_quality, 1))), raster_device_param))
        grid.rect(x[expand_index[[2]]], y[expand_index[[1]]], width = unit(1/nc, 'npc'), height = unit(1/nr, 'npc'), gp = do.call('gpar', c(list(fill = col_matrix), gp)))
        dev.off()
        q(save = 'no')
        ", code.pattern = "@\\{CODE\\}")
        writeLines(R_cmd, con = temp_R_file)

        # quote the script path when it contains spaces
        script_path = if(grepl(" ", temp_R_file)) paste0("'", temp_R_file, "'") else temp_R_file
        cmd = qq("\"@{normalizePath(R_binary(), winslash='/')}\" --vanilla < @{script_path}", code.pattern = "@\\{CODE\\}")
        if(is_windows()) {
            # `show.output.on.console` is only accepted by system() on Windows
            oe = try(system(cmd, ignore.stdout = TRUE, ignore.stderr = TRUE, show.output.on.console = FALSE), silent = TRUE)
        } else {
            oe = try(system(cmd, ignore.stdout = TRUE, ignore.stderr = TRUE), silent = TRUE)
        }
        ############################################
        if(inherits(oe, "try-error")) {
            stop(oe)
        }
        if(!file.exists(temp_image)) {
            stop("Failed to generate the raster image for the heatmap body.")
        }
        image = getFromNamespace(device_info[3], ns = device_info[2])(temp_image)
        image = as.raster(image)
        grid.raster(image, width = unit(1, "npc"), height = unit(1, "npc"))

    } else {
        # `type = "none"` in gp suppresses drawing of the cell rectangles
        if(!isTRUE(gp$type == "none")) {
            grid.rect(x[expand_index[[2]]], y[expand_index[[1]]], width = unit(1/nc, "npc"), height = unit(1/nr, "npc"), gp = do.call("gpar", c(list(fill = col_matrix), gp)))
        }

        if(is.function(cell_fun)) {
            # iterate over positions in the ordered slice; the original matrix indices
            # are looked up from the orders instead of searched with which()
            for(ri in seq_along(row_order)) {
                for(ci in seq_along(column_order)) {
                    cell_fun(column_order[ci], row_order[ri], unit(x[ci], "npc"), unit(y[ri], "npc"), unit(1/nc, "npc"), unit(1/nr, "npc"), col_matrix[ri, ci])
                }
            }
        }
    }

    upViewport()

})

# Whether R is running on Windows (based on `.Platform$OS.type`).
is_windows = function() {
    os_type = .Platform$OS.type
    identical(tolower(os_type), "windows")
}

# Full path of the R executable of the running R installation
# ("R.exe" on Windows, "R" elsewhere, inside R.home("bin")).
R_binary = function() {
    # a scalar condition: plain if/else instead of ifelse()
    R_exe = if(is_windows()) "R.exe" else "R"
    file.path(R.home("bin"), R_exe)
}

# == title
# Draw dendrogram on row or column
#
# == param
# -object a `Heatmap-class` object.
# -which is dendrogram put on the row or on the column of the heatmap?
# -k a matrix may be split by rows, the value identifies which row-slice.
# -max_height maximum height of the dendrograms.
# -... pass to `grid::viewport`, basically for defining the position of the viewport.
#
# == details
# If the matrix is split into several row slices, a list of dendrograms will be drawn by 
# the heatmap that each dendrogram corresponds to its row slices.
#
# A viewport is created which contains dendrograms. Nothing is drawn if the dendrogram
# is empty or has at most one leaf.
#
# This function is only for internal use.
#
# == value
# This function returns no value.
#
# == seealso
# `grid.dendrogram`
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw_dend",
    signature = "Heatmap",
    definition = function(object,
    which = c("row", "column"), k = 1, max_height = NULL, ...) {

    which = match.arg(which)[1]
    
    side = switch(which,
        "row" = object@row_dend_param$side,
        "column" = object@column_dend_param$side)
    
    hc = switch(which,
        "row" = object@row_dend_list[[k]],
        "column" = object@column_dend)
    
    # NOTE(review): `gp` is read but not passed to grid.dendrogram below; confirm whether
    # it should be applied to the dendrogram edges.
    gp = switch(which,
        "row" = object@row_dend_param$gp,
        "column" = object@column_dend_param$gp)

    # length(NULL) is 0, so this also covers slices without a dendrogram
    if(length(hc) == 0) {
        return(invisible(NULL))
    }

    dend = as.dendrogram(hc)
    if(nobs(dend) <= 1) {
        return(invisible(NULL))
    }

    dend_padding = unit(1, "mm")
    pushViewport(viewport(name = paste(object@name, which, "cluster", k, sep = "_"), ...))

    row_dend_name = paste(object@name, "dend_row", k, sep = "_")
    column_dend_name = paste(object@name, "dend_column", sep = "_")
    # the dendrogram faces towards the heatmap body
    if(side == "left") {
        grid.dendrogram(dend, name = row_dend_name, max_height = max_height, facing = "right", order = "reverse", x = dend_padding, width = unit(1, "npc") - dend_padding*2, just = "left")
    } else if(side == "right") {
        grid.dendrogram(dend, name = row_dend_name, max_height = max_height, facing = "left", order = "reverse", x = unit(0, "mm"), width = unit(1, "npc") - dend_padding*2, just = "left")
    } else if(side == "top") {
        grid.dendrogram(dend, name = column_dend_name, max_height = max_height, facing = "bottom", y = dend_padding, height = unit(1, "npc") - dend_padding*2, just = "bottom")
    } else if(side == "bottom") {
        grid.dendrogram(dend, name = column_dend_name, max_height = max_height, facing = "top", y = dend_padding, height = unit(1, "npc") - dend_padding*2, just = "bottom")
    } 

    upViewport()

})

# == title
# Draw row names or column names
#
# == param
# -object a `Heatmap-class` object.
# -which are names put on the row or on the column of the heatmap?
# -k a matrix may be split by rows, the value identifies which row-slice.
# -dimname_padding padding for the row/column names
# -... pass to `grid::viewport`, basically for defining the position of the viewport.
#
# == details
# A viewport is created which contains row names or column names, in the same order as the
# rows (of slice ``k``) or columns of the heatmap body. Column names are rotated by 90 degrees.
# Nothing is drawn if the matrix has no row names / column names.
#
# This function is only for internal use.
#
# == value
# This function returns no value.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw_dimnames",
    signature = "Heatmap",
    definition = function(object,
    which = c("row", "column"), k = 1, dimname_padding = unit(0, "mm"), ...) {

    which = match.arg(which)[1]

    # pick the ordering and the parameters for the requested dimension
    if(which == "row") {
        ind = object@row_order_list[[k]]
        names_param = object@row_names_param
        label_text = rownames(object@matrix)[ind]
    } else {
        ind = object@column_order
        names_param = object@column_names_param
        label_text = colnames(object@matrix)[ind]
    }
    side = names_param$side
    text_gp = subset_gp(names_param$gp, ind)

    if(is.null(label_text)) {
        return(invisible(NULL))
    }

    n_label = length(label_text)
    
    if(which == "column") {
        pushViewport(viewport(name = paste(object@name, "column_names", sep = "_"), ...))
        at_top = side == "top"
        # rotated text grows away from the heatmap body on either side
        grid.text(label_text,
            x = (seq_len(n_label) - 0.5) / n_label,
            y = if(at_top) unit(0, "npc") + dimname_padding else unit(1, "npc") - dimname_padding,
            rot = 90,
            just = if(at_top) c("left", "center") else c("right", "center"),
            gp = text_gp)
    } else {
        pushViewport(viewport(name = paste(object@name, "row_names", k, sep = "_"), ...))
        on_left = side == "left"
        # text is aligned towards the heatmap body
        grid.text(label_text,
            x = if(on_left) unit(1, "npc") - dimname_padding else unit(0, "npc") + dimname_padding,
            y = (rev(seq_len(n_label)) - 0.5) / n_label,
            just = if(on_left) c("right", "center") else c("left", "center"),
            gp = text_gp)
    }

    upViewport()
})

# == title
# Draw heatmap title
#
# == param
# -object a `Heatmap-class` object.
# -which is the title for the rows ("row") or for the columns ("column") of the heatmap?
# -k a matrix may be split by rows, the value identifies which row-slice. It is not used for the column title.
# -... pass to `grid::viewport`, basically for defining the position of the viewport.
#
# == details
# A viewport with clipping turned off is created which contains the heatmap title.
# A row title uses the k-th row title and the graphic parameters subset for slice k.
# The anchor of the title is placed 2.5 mm away from the viewport edge that faces the
# heatmap body.
#
# This function is only for internal use.
#
# == value
# This function returns no value.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw_title",
    signature = "Heatmap",
    definition = function(object,
    which = c("row", "column"), k = 1, ...) {

    which = match.arg(which)[1]
    title_padding = unit(2.5, "mm")

    if(which == "row") {
        title_param = object@row_title_param
        # one set of graphic parameters per row slice
        title_gp = subset_gp(title_param$gp, k)
        title = object@row_title[k]

        pushViewport(viewport(name = paste(object@name, "row_title", k, sep = "_"), clip = FALSE, ...))
        anchor_x = if(title_param$side == "left") unit(1, "npc") - title_padding else title_padding
        grid.text(title, x = anchor_x, rot = object@row_title_rot,
            just = object@row_title_just, gp = title_gp)
        upViewport()
    } else {
        title_param = object@column_title_param

        pushViewport(viewport(name = paste(object@name, "column_title", sep = "_"), clip = FALSE, ...))
        anchor_y = if(title_param$side == "top") title_padding else unit(1, "npc") - title_padding
        grid.text(object@column_title, y = anchor_y, rot = object@column_title_rot,
            just = object@column_title_just, gp = title_param$gp)
        upViewport()
    }
})

# == title
# Draw column annotations
#
# == param
# -object a `Heatmap-class` object.
# -which are the annotations put on the top or bottom of the heatmap?
#
# == details
# A viewport is created which contains column annotations.
#
# Since the column annotations are stored as a `HeatmapAnnotation-class` object, the function
# calls `draw,HeatmapAnnotation-method` to draw them in the current column order. A 1 mm
# padding is left on the side that faces the heatmap body. Nothing is drawn when there is
# no annotation on the requested side.
#
# This function is only for internal use.
#
# == value
# This function returns no value.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw_annotation",
    signature = "Heatmap",
    definition = function(object, which = c("top", "bottom")) {

    which = match.arg(which)[1]
    on_top = which == "top"
    annotation = if(on_top) object@top_annotation else object@bottom_annotation

    # no annotation on this side: nothing to draw
    if(is.null(annotation)) {
        return(invisible(NULL))
    }

    padding = unit(1, "mm")
    annotation_height = unit(1, "npc") - padding

    if(on_top) {
        # padding below the annotation
        draw(annotation, index = object@column_order, y = padding,
            height = annotation_height, just = "bottom")
    } else {
        # padding above the annotation
        draw(annotation, index = object@column_order, y = unit(0, "mm"),
            height = annotation_height, just = "bottom", align_to = "top")
    }
})

# == title
# Width of each heatmap component
#
# == param
# -object a `Heatmap-class` object.
# -k which components in the heatmap (values from 1 to 7), see `Heatmap-class`.
#
# == details
# Component 4 is the heatmap body: it has zero width for a zero-column matrix, otherwise it
# uses the user-specified width if that is a `grid::unit`, or 1 "null" unit. The widths of the
# other components are read from the pre-computed layout.
#
# This function is only for internal use.
#
# == value
# A `grid::unit` object with one width for each value in k.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "component_width",
    signature = "Heatmap",
    definition = function(object, k = 1:7) {

    # layout fields holding the widths of the components around the body
    side_index = c(1, 2, 3, 5, 6, 7)
    side_field = c("layout_row_title_left_width",
                   "layout_row_dend_left_width",
                   "layout_row_names_left_width",
                   "layout_row_names_right_width",
                   "layout_row_dend_right_width",
                   "layout_row_title_right_width")

    .body_width = function() {
        if(ncol(object@matrix) == 0) {
            return(unit(0, "mm"))
        }
        body_width = object@heatmap_param$width
        if(is.unit(body_width)) body_width else unit(1, "null")
    }

    .single_unit = function(i) {
        if(i == 4) {
            return(.body_width())
        }
        pos = which(side_index == i)
        if(length(pos) == 0) {
            stop("wrong 'k'")
        }
        object@layout[[ side_field[pos] ]]
    }

    do.call("unit.c", lapply(k, .single_unit))
})

# == title
# Height of each heatmap component
#
# == param
# -object a `Heatmap-class` object.
# -k which components in the heatmap (values from 1 to 9), see `Heatmap-class`.
#
# == detail
# Component 5 is the heatmap body and always gets 1 "null" unit. The heights of the other
# components are read from the pre-computed layout.
#
# This function is only for internal use.
#
# == value
# A `grid::unit` object with one height for each value in k.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "component_height",
    signature = "Heatmap",
    definition = function(object, k = 1:9) {

    # layout fields holding the heights of the components above and below the body
    non_body_index = c(1, 2, 3, 4, 6, 7, 8, 9)
    non_body_field = c("layout_column_title_top_height",
                       "layout_column_dend_top_height",
                       "layout_column_anno_top_height",
                       "layout_column_names_top_height",
                       "layout_column_names_bottom_height",
                       "layout_column_anno_bottom_height",
                       "layout_column_dend_bottom_height",
                       "layout_column_title_bottom_height")

    .single_unit = function(i) {
        if(i == 5) {
            return(unit(1, "null"))
        }
        pos = which(non_body_index == i)
        if(length(pos) == 0) {
            stop("wrong 'k'")
        }
        object@layout[[ non_body_field[pos] ]]
    }

    do.call("unit.c", lapply(k, .single_unit))
})

# == title
# Set height of a heatmap component
#
# == param
# -object a `Heatmap-class` object.
# -k which component (a single value), see `Heatmap-class`. Allowed values are 1, 2, 3, 4, 6, 7, 8 and 9; the heatmap body (5) cannot be set.
# -v height of the component, a `grid::unit` object.
#
# == detail
# The height is stored in the layout of the heatmap. The object passed in is not modified in place,
# so the returned object must be kept by the caller.
#
# This function is only for internal use.
#
# == value
# The updated `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "set_component_height",
    signature = "Heatmap",
    definition = function(object, k, v) {

    # layout fields for the components above (1-4) and below (6-9) the body;
    # the body itself (5) has no settable height
    component_index = c(1, 2, 3, 4, 6, 7, 8, 9)
    layout_field = c("layout_column_title_top_height",
                     "layout_column_dend_top_height",
                     "layout_column_anno_top_height",
                     "layout_column_names_top_height",
                     "layout_column_names_bottom_height",
                     "layout_column_anno_bottom_height",
                     "layout_column_dend_bottom_height",
                     "layout_column_title_bottom_height")

    # k must be one single component; a vector k would otherwise be
    # recycled in the comparison and could silently select a component
    if(length(k) != 1 || is.na(k)) {
        stop("wrong 'k'")
    }
    pos = which(component_index == k)
    if(length(pos) != 1) {
        stop("wrong 'k'")
    }

    object@layout[[ layout_field[pos] ]] = v

    return(object)
})

# == title
# Draw a single heatmap
#
# == param
# -object a `Heatmap-class` object.
# -internal only used inside the calling of `draw,HeatmapList-method`. Only heatmap without legends will be drawn.
# -test only for testing. If TRUE, the heatmap is prepared by `prepare,Heatmap-method` and drawn without legends on a new page.
# -... pass to `draw,HeatmapList-method`.
#
# == detail
# By default, the function creates a `HeatmapList-class` object which only contains this single heatmap
# and calls `draw,HeatmapList-method` to make the final heatmap. A heatmap with a zero-column matrix is
# rejected in this mode.
#
# With internal = TRUE, the heatmap components are drawn into a 9 x 7 grid layout whose widths and
# heights come from `component_width,Heatmap-method` and `component_height,Heatmap-method`.
#
# == value
# The value of the last drawing call is passed through; in the default mode this is whatever
# `draw,HeatmapList-method` returns.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "draw",
    signature = "Heatmap",
    definition = function(object, internal = FALSE, test = FALSE, ...) {

    # draw every heatmap component into its cell of the 9 x 7 layout
    .draw_components = function(ht) {
        grid_layout = grid.layout(nrow = 9, ncol = 7, widths = component_width(ht, 1:7),
            heights = component_height(ht, 1:9))
        pushViewport(viewport(layout = grid_layout))
        cell_index = ht@layout$layout_index
        graphic_funs = ht@layout$graphic_fun_list
        for(j in seq_len(nrow(cell_index))) {
            row_pos = cell_index[j, 1]
            col_pos = cell_index[j, 2]
            if(row_pos == 5 && col_pos == 4) {
                # the cell of the heatmap body gets a named viewport
                cell_vp = viewport(layout.pos.row = row_pos, layout.pos.col = col_pos,
                    name = paste(ht@name, "heatmap_body_wrap", sep = "_"))
            } else {
                cell_vp = viewport(layout.pos.row = row_pos, layout.pos.col = col_pos)
            }
            pushViewport(cell_vp)
            graphic_funs[[j]](ht)
            upViewport()
        }
        upViewport()
    }

    if(test) {
        object = prepare(object)
        grid.newpage()
        draw(object, internal = TRUE)
    } else if(!internal) {
        if(ncol(object@matrix) == 0) {
            stop("Single heatmap should contains a matrix with at least one column.\nZero-column matrix can only be appended to the heatmap list.")
        }
        draw(add_heatmap(new("HeatmapList"), object), ...)
    } else {
        # a heatmap without legend
        .draw_components(object)
    }
})

# == title
# Prepare the heatmap
#
# == param
# -object a `Heatmap-class` object.
# -process_rows whether process rows of the heatmap
#
# == detail
# The preparation of the heatmap includes following steps:
#
# - processing rows (by calling `make_row_cluster,Heatmap-method`) if process_rows is TRUE
# - making clustering on columns (by calling `make_column_cluster,Heatmap-method`) if column clustering is turned on
# - making the layout of the heatmap (by calling `make_layout,Heatmap-method`)
#
# This function is only for internal use.
#
# == value
# The prepared `Heatmap-class` object.
#
# == author
# Zuguang Gu <z.gu@dkfz.de>
#
setMethod(f = "prepare",
    signature = "Heatmap",
    definition = function(object, process_rows = TRUE) {

    if(process_rows) object = make_row_cluster(object)

    # columns are only clustered when the column dendrogram asks for it
    cluster_columns = object@column_dend_param$cluster
    if(cluster_columns) {
        object = make_column_cluster(object)
    }

    object = make_layout(object)
    object
})
